Order of magnitude faster training for image classification: Part II

Transfer learning using Inception Package - Cloud Run Experience

This notebook codifies the capabilities discussed in this blog post. In a nutshell, it uses the pre-trained Inception model as a starting point and then uses transfer learning to train it further on additional, customer-specific images. For illustration, simple flower images are used. Compared to training from scratch, the time and cost are drastically reduced.

This notebook does preprocessing, training and prediction by calling the CloudML API instead of running them in the Datalab container. Local runs are meant for initial prototyping and debugging on small-scale data, often a suitable sample (say 0.1 - 1%) of the full data. The same basic steps can then be repeated with much larger datasets in the cloud.
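
As a rough illustration of the sampling idea, here is a minimal sketch in plain Python that keeps about 1% of the rows of a dataset. The helper name sample_lines and the gs://my-bucket paths are hypothetical and only stand in for a real file list.

```python
import random

def sample_lines(lines, fraction, seed=0):
    """Keep each line with probability `fraction`, reproducibly for a given seed."""
    rng = random.Random(seed)
    return [line for line in lines if rng.random() < fraction]

# Hypothetical stand-in for the rows of a full image_url,label CSV.
full = ['gs://my-bucket/img/%d.jpg,daisy' % i for i in range(100000)]

# A 1% sample for local prototyping; the exact size varies with the seed.
small = sample_lines(full, 0.01)
print(len(small))
```

A fixed seed keeps the sample stable between runs, which makes local debugging easier to reproduce.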

Setup

Run the following steps first, but only if you are running Datalab from your local desktop or laptop (not from a GCE VM):

  1. Make sure you have a GCP project with the Machine Learning API and Dataflow API enabled.
  2. Run "%datalab project set --project [project-id]" to set the default project in Datalab.

If you run Datalab from a GCE VM, make sure the project of the GCE VM has the Machine Learning API and Dataflow API enabled.


In [4]:
import mltoolbox.image.classification as model
from google.datalab.ml import *

bucket = 'gs://' + datalab_project_id() + '-lab'
preprocess_dir = bucket + '/flowerpreprocessedcloud'
model_dir = bucket + '/flowermodelcloud'
staging_dir = bucket + '/staging'

In [ ]:
!gsutil mb $bucket

Preprocess

Preprocessing uses a Dataflow pipeline to convert the image format, resize images, and run the converted images through a pre-trained model to get the features or embeddings. You can also do this step using alternate technologies like Spark or plain Python code if you like. The %%ml preprocess command simplifies this task. Check out the parameters shown with the --usage flag first and then run the command.

If you hit "PERMISSION_DENIED" when running the following cell, you need to enable the Cloud Dataflow API (the URL is shown in the error message).

The Dataflow job usually takes about 20 minutes to complete.


In [2]:
train_set = CsvDataSet('gs://cloud-datalab/sampledata/flower/train1000.csv', schema='image_url:STRING,label:STRING')
preprocess_job = model.preprocess_async(train_set, preprocess_dir, cloud={'num_workers': 10})
preprocess_job.wait() # Alternatively, you can query the job status by preprocess_job.state. The wait() call blocks the notebook execution.


/usr/local/lib/python2.7/dist-packages/apache_beam/coders/typecoders.py:136: UserWarning: Using fallback coder for typehint: Any.
  warnings.warn('Using fallback coder for typehint: %r.' % typehint)
Job "preprocess-image-classification-170304-052813" submitted.

Click here to track preprocessing job.
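
The schema passed to CsvDataSet above declares two string columns, image_url and label. The following minimal sketch reads such data with plain Python and collects the label vocabulary. It assumes a headerless file with one "image_url,label" pair per line, which is not confirmed by this notebook; the helper name read_labeled_images is hypothetical.

```python
def read_labeled_images(text):
    """Parse 'image_url,label' lines into (image_url, label) tuples."""
    rows = []
    for line in text.splitlines():
        line = line.strip()
        if not line:
            continue  # skip blank lines
        # Split on the last comma so the label is always the final field.
        image_url, label = line.rsplit(',', 1)
        rows.append((image_url, label))
    return rows

# Two sample rows built from image paths used later in this notebook.
sample = (
    "gs://cloud-ml-data/img/flower_photos/daisy/15207766_fc2f1d692c_n.jpg,daisy\n"
    "gs://cloud-ml-data/img/flower_photos/tulips/6876631336_54bf150990.jpg,tulips\n"
)
rows = read_labeled_images(sample)
labels = sorted(set(label for _, label in rows))
print(labels)  # ['daisy', 'tulips']
```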

Train

Note that the command is the same as in the "local" version, apart from the cloud argument.


In [3]:
# Positional arguments: input directory, batch size (30), maximum training steps (1000), output directory.
train_job = model.train_async(preprocess_dir, 30, 1000, model_dir, cloud=CloudTrainingConfig('us-central1', 'BASIC'))
train_job.wait() # Alternatively, you can query the job status by train_job.state. The wait() call blocks the notebook execution.


Job "image_classification_train_170307_002934" submitted.

Click here to view cloud log.

Out[3]:
Job image_classification_train_170307_002934 completed

Check your job status by running the following (replace the job id with the one shown above):

Job('image_classification_train_170307_002934').describe()
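
If blocking on wait() is not desirable, the job state can be polled instead. The sketch below shows one possible polling loop. It takes any zero-argument callable that returns the current state, so it does not depend on a particular job class. The helper name wait_for_job and the set of terminal state names are assumptions; check the states your jobs actually report.

```python
import time

def wait_for_job(get_state, poll_seconds=30, timeout_seconds=3600,
                 terminal_states=('SUCCEEDED', 'FAILED', 'CANCELLED')):
    """Poll get_state() until it returns a terminal state or the timeout elapses."""
    deadline = time.time() + timeout_seconds
    while True:
        state = get_state()
        if state in terminal_states:
            return state
        if time.time() >= deadline:
            raise RuntimeError('Timed out; last state was %r' % (state,))
        time.sleep(poll_seconds)

# Simulated job that moves through three states (assumed names).
states = iter(['QUEUED', 'RUNNING', 'SUCCEEDED'])
final_state = wait_for_job(lambda: next(states), poll_seconds=0)
print(final_state)  # SUCCEEDED
```

With the training job above, the call might look like wait_for_job(lambda: train_job.state), assuming the state attribute refreshes on each access, which is not verified here.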

TensorBoard also works with a GCS path. Note that the data usually shows up about a minute after TensorBoard starts with a GCS path.


In [ ]:
tb_id = TensorBoard.start(model_dir)

Predict

Deploy the model and run online predictions. The deployment takes about 2 to 5 minutes.


In [21]:
Models().create('flower')
ModelVersions('flower').deploy('beta1', model_dir)


Waiting for operation "projects/bradley-playground/operations/create_flower_beta1-1488494327528"
Done.

Online prediction is currently in alpha. If the first call fails, run it again; retrying helps ensure a warm start.
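
One way to handle a failed first call is a small retry wrapper. This is a minimal sketch, not part of the Datalab API; the helper name call_with_retry, the number of attempts and the delay are all assumptions, and the flaky function only simulates a failing first request.

```python
import time

def call_with_retry(fn, attempts=3, delay_seconds=5):
    """Call fn(); on an exception, wait and retry, up to `attempts` calls in total."""
    for attempt in range(1, attempts + 1):
        try:
            return fn()
        except Exception:
            if attempt == attempts:
                raise  # out of attempts: surface the last error
            time.sleep(delay_seconds)

# Simulated service whose first request fails and whose second succeeds.
calls = {'n': 0}

def flaky():
    calls['n'] += 1
    if calls['n'] == 1:
        raise RuntimeError('simulated cold start')
    return 'ok'

result = call_with_retry(flaky, delay_seconds=0)
print(result)  # ok
```

The prediction call in the next cell could be wrapped the same way by passing it as a lambda.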


In [6]:
images = [
    'gs://cloud-ml-data/img/flower_photos/daisy/15207766_fc2f1d692c_n.jpg',
    'gs://cloud-ml-data/img/flower_photos/tulips/6876631336_54bf150990.jpg'
]
# Set resize=True to avoid sending large images in the prediction request.
model.predict('flower.beta1', images, resize=True, cloud=True)


Predicting...

daisy(0.99997)

tulips(0.99999)

Out[6]:
image_url label score
0 gs://cloud-ml-data/img/flower_photos/daisy/152... daisy 0.999969
1 gs://cloud-ml-data/img/flower_photos/tulips/68... tulips 0.999988

Batch Predict


In [ ]:
import google.datalab.bigquery as bq

bq.Dataset('flower').create()
eval_set = CsvDataSet('gs://cloud-datalab/sampledata/flower/eval670.csv', schema='image_url:STRING,label:STRING')
batch_predict_job = model.batch_predict_async(eval_set, model_dir, output_bq_table='flower.eval_results_full',
                                              cloud={'temp_location': staging_dir})
batch_predict_job.wait()


/usr/local/lib/python2.7/dist-packages/apache_beam/coders/typecoders.py:136: UserWarning: Using fallback coder for typehint: Any.
  warnings.warn('Using fallback coder for typehint: %r.' % typehint)
Job "batch-predict-image-classification-170307-004521" submitted.

Click here to track batch prediction job.


In [1]:
%%bq query --name wrong_prediction

SELECT * FROM flower.eval_results_full WHERE target != predicted

In [2]:
wrong_prediction.execute().result()


Out[2]:
image_url                                                                target  predicted   target_prob        predicted_prob
gs://cloud-ml-data/img/flower_photos/daisy/530738000_4df7e4786b.jpg      daisy   roses       0.00432691676542   0.943398058414
gs://cloud-ml-data/img/flower_photos/daisy/10172636503_21bededa75_n.jpg  daisy   roses       0.389928817749     0.591014802456
gs://cloud-ml-data/img/flower_photos/daisy/391364011_5beaaa1ae2_m.jpg    daisy   roses       0.0353457517922    0.558989524841
gs://cloud-ml-data/img/flower_photos/daisy/799964360_7e07a227ea_n.jpg    daisy   tulips      0.0849908292294    0.870252728462
gs://cloud-ml-data/img/flower_photos/daisy/835750256_3f91a147ef_n.jpg    daisy   tulips      2.95653935609e-05  0.995988786221
gs://cloud-ml-data/img/flower_photos/daisy/7320089276_87b544e341.jpg     daisy   dandelion   0.0026037034113    0.806947886944
gs://cloud-ml-data/img/flower_photos/daisy/14088053307_1a13a0bf91_n.jpg  daisy   dandelion   0.0288260094821    0.759018003941
gs://cloud-ml-data/img/flower_photos/daisy/2635314490_e12d3b0f36_m.jpg   daisy   dandelion   0.440836697817     0.461920529604
gs://cloud-ml-data/img/flower_photos/daisy/3337643329_accc9b5426.jpg     daisy   dandelion   0.353675484657     0.601798713207
gs://cloud-ml-data/img/flower_photos/daisy/19019544592_b64469bf84_n.jpg  daisy   dandelion   0.0870867073536    0.883496463299
gs://cloud-ml-data/img/flower_photos/daisy/2611119198_9d46b94392.jpg     daisy   dandelion   0.152479454875     0.817649424076
gs://cloud-ml-data/img/flower_photos/daisy/4432271543_01c56ca3a9.jpg     daisy   sunflowers  0.000175354522071  0.986194372177
gs://cloud-ml-data/img/flower_photos/daisy/10994032453_ac7f8d9e2e.jpg    daisy   sunflowers  0.281415820122     0.483355402946
gs://cloud-ml-data/img/flower_photos/roses/5212885371_fe27c406a2_n.jpg   roses   daisy       2.51606866186e-06  0.627713322639
gs://cloud-ml-data/img/flower_photos/roses/5212877807_a3ddf06a7c_n.jpg   roses   daisy       1.07227799617e-06  0.97923719883
gs://cloud-ml-data/img/flower_photos/roses/475936554_a2b38aaa8e.jpg      roses   daisy       0.215228050947     0.689969658852
gs://cloud-ml-data/img/flower_photos/roses/15190665092_5c1c37a066_m.jpg  roses   tulips      0.443253934383     0.553740203381
gs://cloud-ml-data/img/flower_photos/roses/9337528427_3d09b7012b.jpg     roses   tulips      0.000572251097765  0.990702390671
gs://cloud-ml-data/img/flower_photos/roses/4998708839_c53ee536a8_n.jpg   roses   tulips      0.000976198993158  0.998490691185
gs://cloud-ml-data/img/flower_photos/roses/17040847367_b54d05bf52.jpg    roses   tulips      0.0716435536742    0.928229928017
gs://cloud-ml-data/img/flower_photos/roses/5292988046_a10f4b0365_n.jpg   roses   tulips      0.0487577728927    0.935128629208
gs://cloud-ml-data/img/flower_photos/roses/14414100710_753a36fce9.jpg    roses   tulips      0.203332111239     0.791280150414
gs://cloud-ml-data/img/flower_photos/roses/16424992340_c1d9eb72b4.jpg    roses   tulips      0.109653167427     0.476714611053
gs://cloud-ml-data/img/flower_photos/roses/12562723334_a2e0a9e3c8_n.jpg  roses   tulips      0.0226334705949    0.939678549767
gs://cloud-ml-data/img/flower_photos/roses/7345657862_689366e79a.jpg     roses   tulips      0.0476251542568    0.947963654995

(rows: 82, time: 1.6s, 72KB processed, job: job_03vp-LpZgDa9Ylp4nTOpg6CR-24)

In [5]:
ConfusionMatrix.from_bigquery('flower.eval_results_full').plot()
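
Conceptually, a confusion matrix counts how often each (target, predicted) combination occurs. The sketch below shows that counting step in plain Python on a handful of made-up pairs. It is not how ConfusionMatrix.from_bigquery is implemented, which this notebook does not show, and the pairs are illustrative rather than taken from the real evaluation results.

```python
from collections import Counter

def confusion_counts(pairs):
    """Count occurrences of each (target, predicted) combination."""
    return Counter(pairs)

# Illustrative (target, predicted) pairs, not the real evaluation results.
pairs = [
    ('daisy', 'daisy'), ('daisy', 'roses'), ('roses', 'tulips'),
    ('roses', 'roses'), ('roses', 'tulips'), ('tulips', 'tulips'),
]
counts = confusion_counts(pairs)
labels = sorted(set(label for pair in pairs for label in pair))

# One row per target class, one column per predicted class.
for target in labels:
    row = [counts[(target, predicted)] for predicted in labels]
    print('%-8s %s' % (target, row))
```

Cells on the diagonal are correct predictions; everything off the diagonal corresponds to a row that the wrong_prediction query would return.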



In [6]:
%%bq query --name accuracy

SELECT
  target,
  SUM(CASE WHEN target=predicted THEN 1 ELSE 0 END) as correct,
  COUNT(*) as total,
  SUM(CASE WHEN target=predicted THEN 1 ELSE 0 END)/COUNT(*) as accuracy
FROM
  flower.eval_results_full
GROUP BY
  target

In [7]:
accuracy.execute().result()


Out[7]:
target      correct  total  accuracy
tulips      111      130    0.853846153846
dandelion   146      162    0.901234567901
sunflowers  122      137    0.890510948905
roses       100      119    0.840336134454
daisy       109      122    0.893442622951

(rows: 5, time: 1.5s, 12KB processed, job: job_Rd8B0IcxRkuoB5tqhjYolDtbv0I)
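
The per-class counts above can be combined into a single overall accuracy. The following short sketch does that arithmetic; the numbers are copied from the query result, and the variable names are only illustrative.

```python
# Per-class (correct, total) counts copied from the accuracy query result above.
per_class = {
    'tulips': (111, 130),
    'dandelion': (146, 162),
    'sunflowers': (122, 137),
    'roses': (100, 119),
    'daisy': (109, 122),
}

correct = sum(c for c, _ in per_class.values())
total = sum(t for _, t in per_class.values())
overall = float(correct) / total

print('%d / %d = %.4f' % (correct, total, overall))  # 588 / 670 = 0.8776
```

The difference, 670 - 588 = 82, matches the row count reported by the wrong_prediction query.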

In [8]:
%%bq query --name logloss

SELECT feature, AVG(-logloss) as logloss, count(*) as count FROM
(
  SELECT feature, CASE WHEN correct=1 THEN LOG(prob) ELSE LOG(1-prob) END as logloss
  FROM
  (
    SELECT
      target as feature,
      CASE WHEN target=predicted THEN 1 ELSE 0 END as correct,
      target_prob as prob
    FROM flower.eval_results_full
  )
)
GROUP BY feature

In [9]:
FeatureSliceView().plot(logloss)
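
To make the logloss query easier to follow, here is a minimal Python sketch of the same per-class computation. It mirrors the query's CASE expression as written and uses the natural logarithm. The rows are illustrative values, not the real evaluation results, and the helper name per_class_logloss is hypothetical.

```python
import math
from collections import defaultdict

def per_class_logloss(rows):
    """rows: (target, predicted, target_prob) tuples. Returns {target: (avg_loss, count)}."""
    losses = defaultdict(list)
    for target, predicted, target_prob in rows:
        # Same CASE expression as the query: LOG(prob) if correct, else LOG(1 - prob).
        if target == predicted:
            value = math.log(target_prob)
        else:
            value = math.log(1 - target_prob)
        losses[target].append(-value)
    return {t: (sum(v) / len(v), len(v)) for t, v in losses.items()}

# Illustrative rows, not the real evaluation results.
rows = [
    ('daisy', 'daisy', 0.9),
    ('daisy', 'roses', 0.2),
    ('roses', 'roses', 0.8),
]
result = per_class_logloss(rows)
for feature in sorted(result):
    print('%-6s logloss=%.4f count=%d' % ((feature,) + result[feature]))
```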


Clean up


In [ ]:
ModelVersions('flower').delete('beta1')
Models().delete('flower')
!gsutil -m rm -r {preprocess_dir}
!gsutil -m rm -r {model_dir}
